13. 测试

测试训练后的模型

我将演示两种测试模型的好方法:使用测试数据和进行推理。第一种方法和在 CNN 课程中提到的方法相似。在 test_loader 中迭代测试数据,记录测试损失并根据模型预测正确的标签数计算准确率。

计算方法是查看输出的舍入值。输出是一个介于 0-1 之间的 S 型函数输出,所以舍入值将是一个整数,表示概率最大的标签;0 或 1。然后将预测标签与真实标签进行比较;如果匹配,则记录为标签正确的测试影评。

# Get test data loss and accuracy

test_losses = [] # track loss
num_correct = 0

# init hidden state
h = net.init_hidden(batch_size)

net.eval()
# iterate over test data
for inputs, labels in test_loader:

    # Creating new variables for the hidden state, otherwise
    # we'd backprop through the entire training history
    h = tuple([each.data for each in h])

    if(train_on_gpu):
        inputs, labels = inputs.cuda(), labels.cuda()

    # get predicted outputs
    output, h = net(inputs, h)

    # calculate loss
    test_loss = criterion(output.squeeze(), labels.float())
    test_losses.append(test_loss.item())

    # convert output probabilities to predicted class (0 or 1)
    pred = torch.round(output.squeeze())  # rounds to the nearest integer

    # compare predictions to true label
    correct_tensor = pred.eq(labels.float().view_as(pred))
    correct = np.squeeze(correct_tensor.numpy()) if not train_on_gpu else np.squeeze(correct_tensor.cpu().numpy())
    num_correct += np.sum(correct)


# -- stats! -- ##
# avg test loss
print("Test loss: {:.3f}".format(np.mean(test_losses)))

# accuracy over all test data
test_acc = num_correct/len(test_loader.dataset)
print("Test accuracy: {:.3f}".format(test_acc))

在下面输出平均测试损失和准确率,即分类正确的项目数除以测试数据总数。

测试损失是 0.516,准确率约为 81.1%

测试结果

测试结果

接下来是最后一个任务了。也就是定义 predict 函数,用模型对任何给定文本影评进行推理。

练习:对测试影评进行推理

你可以将 test_review 更改为任何文本。读一读并判断:是正面还是负面影评?然后看看模型能否正确预测。

练习:请编写一个 predict 函数,参数包括训练过的网络、纯文本影评和序列长度,然后针对正面或负面影评输出自定义描述性句子。

  • 你可以使用你已经定义过的任何函数,或定义帮助完成 predict 的辅助函数,但是参数只能包括训练过的网络、文本影评和序列长度。
def predict(net, test_review, sequence_length=200):
    ''' Prints out whether a give review is predicted to be 
        positive or negative in sentiment, using a trained model.

        params:
        net - A trained net 
        test_review - a review made of normal text and punctuation
        sequence_length - the padded length of a review
        '''


    # print custom response based on whether test_review is pos/neg

请试着自己完成这道练习,然后看看 solution。